{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## install requirements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GATConv\n",
    "import networkx as nx\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create sample data for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sample_data():\n",
    "    # Create a small network with 10 nodes and 15 edges\n",
    "    num_nodes = 10\n",
    "    num_edges = 15\n",
    "    \n",
    "    # Node features (random for demonstration)\n",
    "    x = torch.randn(num_nodes, 16)  # 16 features per node\n",
    "    \n",
    "    # Random edges\n",
    "    edge_index = torch.randint(0, num_nodes, (2, num_edges))\n",
    "    \n",
    "    # Edge features\n",
    "    edge_attr = torch.randn(num_edges, 8)  # 8 features per edge\n",
    "    \n",
    "    # Labels (for supervised learning)\n",
    "    y = torch.randint(0, 2, (num_nodes,)).float()\n",
    "    \n",
    "    return x, edge_index, edge_attr, y\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Construction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NetworkGraph:\n",
    "    def construct_graph(self, network_data):\n",
    "        nodes = self.create_nodes(network_data)\n",
    "        edges = self.create_edges(network_data)\n",
    "        \n",
    "        for node in nodes:\n",
    "            node.features = self.extract_node_features(node)\n",
    "        \n",
    "        for edge in edges:\n",
    "            edge.features = self.extract_edge_features(edge)\n",
    "        \n",
    "        return nodes, edges\n",
    "    \n",
    "    def create_nodes(self, network_data):\n",
    "        # Simplified example\n",
    "        return [{'id': i} for i in range(len(network_data))]\n",
    "    \n",
    "    def create_edges(self, network_data):\n",
    "        # Simplified example\n",
    "        return [{'source': i, 'target': i+1} for i in range(len(network_data)-1)]\n",
    "    \n",
    "    def extract_node_features(self, node):\n",
    "        # Simplified feature extraction\n",
    "        return torch.randn(16)  # 16-dimensional feature vector\n",
    "    \n",
    "    def extract_edge_features(self, edge):\n",
    "        # Simplified feature extraction\n",
    "        return torch.randn(8)  # 8-dimensional feature vector\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GNN Model Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class APTDetectionModel(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features, num_layers):\n",
    "        super().__init__()\n",
    "        self.gat_layers = nn.ModuleList([\n",
    "            GATConv(\n",
    "                in_features if i == 0 else hidden_features,\n",
    "                hidden_features\n",
    "            ) for i in range(num_layers)\n",
    "        ])\n",
    "        self.gru = nn.GRU(hidden_features, hidden_features)\n",
    "        self.output = nn.Linear(hidden_features, 1)\n",
    "    \n",
    "    def forward(self, x, edge_index, edge_attr):\n",
    "        for gat in self.gat_layers:\n",
    "            x = F.relu(gat(x, edge_index, edge_attr))\n",
    "        x, _ = self.gru(x.unsqueeze(0))\n",
    "        return self.output(x.squeeze(0)).squeeze(-1)\n",
    "    \n",
    "    def get_attention_weights(self, node):\n",
    "        # Simplified attention weight extraction\n",
    "        return torch.randn(10)  # Random weights for demonstration\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, labeled_data, unlabeled_data, num_epochs=10):\n",
    "    optimizer = torch.optim.Adam(model.parameters())\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        \n",
    "        # Supervised learning\n",
    "        optimizer.zero_grad()\n",
    "        x, edge_index, edge_attr, y = labeled_data\n",
    "        out = model(x, edge_index, edge_attr)\n",
    "        loss = F.binary_cross_entropy_with_logits(out, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # Print training progress\n",
    "        if epoch % 2 == 0:\n",
    "            print(f\"Epoch {epoch}, Loss: {loss.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## APT Detection Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def detect_apts(model, graph, threshold=0.5):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        x, edge_index, edge_attr, _ = graph\n",
    "        anomaly_scores = torch.sigmoid(model(x, edge_index, edge_attr))\n",
    "        suspicious_nodes = (anomaly_scores > threshold).nonzero().flatten()\n",
    "        \n",
    "        results = []\n",
    "        for node in suspicious_nodes:\n",
    "            attention_weights = model.get_attention_weights(node)\n",
    "            results.append({\n",
    "                'node': node.item(),\n",
    "                'score': anomaly_scores[node].item(),\n",
    "                'attention': attention_weights\n",
    "            })\n",
    "        \n",
    "        return results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training model...\n",
      "Epoch 0, Loss: 0.6991\n",
      "Epoch 2, Loss: 0.6949\n",
      "Epoch 4, Loss: 0.6909\n",
      "Epoch 6, Loss: 0.6869\n",
      "Epoch 8, Loss: 0.6827\n",
      "\n",
      "Detecting APTs...\n",
      "\n",
      "Detection Results:\n",
      "\n",
      "Node 0: Anomaly Score = 0.5238\n",
      "Node 1: Anomaly Score = 0.5263\n",
      "Node 2: Anomaly Score = 0.5223\n",
      "Node 3: Anomaly Score = 0.5004\n",
      "Node 4: Anomaly Score = 0.5336\n",
      "Node 5: Anomaly Score = 0.5058\n",
      "Node 7: Anomaly Score = 0.5324\n",
      "Node 8: Anomaly Score = 0.5373\n",
      "Node 9: Anomaly Score = 0.5334\n"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # Create sample data\n",
    "    x, edge_index, edge_attr, y = create_sample_data()\n",
    "    \n",
    "    # Initialize model\n",
    "    model = APTDetectionModel(in_features=16, hidden_features=32, num_layers=2)\n",
    "    \n",
    "    # Train model\n",
    "    print(\"Training model...\")\n",
    "    train_model(model, (x, edge_index, edge_attr, y), None)\n",
    "    \n",
    "    # Detect APTs\n",
    "    print(\"\\nDetecting APTs...\")\n",
    "    results = detect_apts(model, (x, edge_index, edge_attr, y))\n",
    "    \n",
    "    # Print results\n",
    "    print(\"\\nDetection Results:\\n\")\n",
    "    for result in results:\n",
    "        print(f\"Node {result['node']}: Anomaly Score = {result['score']:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "applied-graph-nn-book",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
